import matplotlib.pyplot as plt
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as py_trans
from mindspore.dataset.transforms.py_transforms import Compose
from PIL import Image

# Pin the dataset-module random seed so the shuffled sample order is
# repeatable from run to run.
ds.set_seed(1)

# Directory holding the CIFAR-10 binary batch files (relative to the CWD).
DATA_DIR = "./datasets/cifar10/cifar-10-batches-bin"

# Draw a small, shuffled subset of CIFAR-10: just four samples.
dataset1 = ds.Cifar10Dataset(
    dataset_dir=DATA_DIR,
    num_samples=4,
    shuffle=True,
)

# Per-image preprocessing, applied in list order:
#   ToPIL    - wrap the raw image as a PIL image so the PIL-based ops can run
#   Resize   - scale the image to 150 x 150
#   Invert   - invert the pixel colours
#   ToTensor - convert back to an array (assumed channel-first C, H, W;
#              confirm against the MindSpore py_transforms docs)
trans_list = [
    py_trans.ToPIL(),
    py_trans.Resize(size=[150, 150]),
    py_trans.Invert(),
    py_trans.ToTensor(),
]

# Fuse the ops into a single callable and apply it to the 'image' column
# only; every other column is left as-is.
compose_trans = Compose(trans_list)
dataset2 = dataset1.map(input_columns=['image'], operations=compose_trans)

# Pull every processed sample out of the pipeline, keeping images and labels
# in two parallel lists, and echo each sample's shapes as we go.
image_list = []
label_list = []
for data in dataset2.create_dict_iterator():
    img, lbl = data['image'], data['label']
    image_list.append(img)
    label_list.append(lbl)

    print(img.shape, lbl.shape)

# Show every collected sample side by side in a single row of subplots,
# each titled with its label.
num_samples = len(image_list)
# image_list and label_list are filled in lockstep, so zip pairs them exactly;
# subplot indices are 1-based, hence start=1.
for pos, (image, label) in enumerate(zip(image_list, label_list), start=1):
    plt.subplot(1, num_samples, pos)
    # The image is channel-first (C, H, W); imshow wants (H, W, C).
    plt.imshow(image.asnumpy().transpose(1, 2, 0))
    plt.title(label.asnumpy())

plt.show()